import torch
import torch.nn as nn
from torchvision import models,transforms

class SiameseNetwork(nn.Module):
    """Siamese network that embeds two images with one shared VGG16 backbone.

    Both inputs are passed through the *same* ``self.vgg`` module, so the two
    branches share all weights. The last layer of the VGG16 classifier is
    removed, so each image is mapped to the output of the remaining
    classifier layers rather than to class logits (presumably 4096-d for
    torchvision's VGG16 -- confirm if the backbone is changed).

    Args:
        pretrained: If True (the default, matching the original behavior),
            load torchvision's ImageNet weights for VGG16. Pass False to get
            a randomly initialised backbone, e.g. offline or in unit tests.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()
        self.vgg = self._build_vgg16(pretrained)
        # Remove the final fully connected layer (the classification head).
        # The other classifier layers stay in place, including its Dropout
        # layers, so embeddings are stochastic while the module is in
        # train() mode and deterministic after eval().
        self.vgg.classifier = self.vgg.classifier[:-1]

    @staticmethod
    def _build_vgg16(pretrained: bool) -> nn.Module:
        """Construct torchvision's VGG16, optionally with ImageNet weights."""
        # torchvision >= 0.13 replaced the boolean ``pretrained`` flag with
        # the ``weights`` enum and emits a deprecation warning for the old
        # flag. ``pretrained=True`` maps to ``VGG16_Weights.IMAGENET1K_V1``.
        # Older releases have no ``VGG16_Weights`` and only accept the flag.
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is None:
            return models.vgg16(pretrained=pretrained)
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.vgg16(weights=weights)

    def forward_once(self, img: torch.Tensor) -> torch.Tensor:
        """Embed one batch of images with the shared backbone.

        ``img`` is presumably an image batch shaped (N, 3, H, W) as expected
        by torchvision's VGG models -- TODO confirm against the data pipeline.
        """
        return self.vgg(img)

    def forward(self, img1: torch.Tensor, img2: torch.Tensor):
        """Embed both inputs with the shared backbone.

        Args:
            img1: First image batch.
            img2: Second image batch.

        Returns:
            A tuple ``(out1, out2)`` holding the feature vectors of ``img1``
            and ``img2`` respectively.
        """
        # The two inputs are run separately rather than concatenated into one
        # batch, so they may have different spatial sizes.
        return self.forward_once(img1), self.forward_once(img2)